import random
from collections import defaultdict

import numpy as np
import torch
from datasets import Dataset, load_dataset
from tqdm import tqdm

from .Base import BaseDataset, UnlearnDataset


class Pile(BaseDataset):
    """Subset of the ``monology/pile-uncopyrighted`` corpus for LM training.

    On construction the first training shard (``train/00.jsonl.zst``) is
    loaded, shuffled with ``seed`` and cut down to the leading ``ratio``
    fraction; the validation file (``val.jsonl.zst``) is loaded in full and
    exposed under the ``"test"`` key.

    Attributes:
        dataset_name: Name given by the caller. It is stored but not read by
            any method of this class; the Hub repository is fixed.
        seed: Seed used for the shuffle and for the global RNGs.
        ratio: Fraction of the shuffled training shard that is kept.
        dataset: Mapping with the keys ``"train"`` and ``"test"``.

    NOTE(review): ``BaseDataset.__init__`` is not invoked here -- confirm
    that the base class needs no initialisation.
    """

    # Hub repository and on-disk cache shared by both splits.
    _HF_REPO = "monology/pile-uncopyrighted"
    _CACHE_DIR = "./.cache"
    # Every example is truncated / padded to exactly this many tokens.
    _MAX_LENGTH = 512
    # Columns exposed as torch tensors after tokenization.
    _TENSOR_COLUMNS = ("input_ids", "attention_mask", "label")

    def __init__(self, dataset_name, seed=0, ratio=0.1):
        """Load the raw splits.

        Args:
            dataset_name: Identifier stored on the instance.
            seed: Seed for shuffling and for the global RNGs.
            ratio: Fraction (0..1 inclusive) of the training shard to keep.

        Raises:
            ValueError: If ``ratio`` lies outside ``[0, 1]``.
        """
        # Fail before any download: a ratio above 1 would otherwise only
        # surface as an out-of-range selection after the shard is loaded,
        # and a negative one would silently yield an empty training set.
        if not 0 <= ratio <= 1:
            raise ValueError(f"ratio must be within [0, 1], got {ratio!r}")
        self.dataset_name = dataset_name
        self.seed = seed
        self.ratio = ratio
        self.dataset = self.get_dataset()

    def get_dataset(self):
        """Load the raw (untokenized) train and test splits.

        Side effect: seeds the global ``torch``, ``numpy`` and ``random``
        generators with ``self.seed``.

        Returns:
            A dict with the keys ``"train"`` (shuffled, subsampled shard)
            and ``"test"`` (the full validation file).
        """
        train_dataset = load_dataset(
            self._HF_REPO,
            cache_dir=self._CACHE_DIR,
            split="train",
            data_files={"train": "train/00.jsonl.zst"},
        )

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        # Shuffle first so the kept prefix is a random sample of the shard.
        train_dataset = train_dataset.shuffle(seed=self.seed)
        keep = int(len(train_dataset) * self.ratio)
        train_dataset = train_dataset.select(range(keep))

        # The validation file is registered under the split name "test".
        test_dataset = load_dataset(
            self._HF_REPO,
            data_files={"test": "val.jsonl.zst"},
            cache_dir=self._CACHE_DIR,
            split="test",
        )

        return {"train": train_dataset, "test": test_dataset}

    def __preprocess__(self, tokenizer):
        """Tokenize both splits in place and expose them as torch tensors.

        The ``"text"`` column is removed by this step, so it can only be
        applied once to a freshly loaded dataset.

        Args:
            tokenizer: Callable applied to a batch of texts; its result must
                provide ``input_ids`` and ``attention_mask`` attributes.

        Returns:
            ``self.dataset`` with both splits replaced by tokenized versions.
        """

        def preprocess(examples):
            tokenized = tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                add_special_tokens=True,
                max_length=self._MAX_LENGTH,
            )
            # The label is the input sequence itself, padding included.
            # NOTE(review): the column is called "label" (singular) --
            # confirm that is the name the downstream collator expects.
            return {
                "input_ids": tokenized.input_ids,
                "attention_mask": tokenized.attention_mask,
                "label": tokenized.input_ids,
            }

        processed = {}
        for split in ("train", "test"):
            tokenized_split = self.dataset[split].map(
                preprocess, batched=True, remove_columns=["text"]
            )
            tokenized_split.set_format(
                type="torch", columns=list(self._TENSOR_COLUMNS)
            )
            processed[split] = tokenized_split

        # Publish only after both splits were tokenized successfully.
        self.dataset["train"] = processed["train"]
        self.dataset["test"] = processed["test"]
        return self.dataset

    def build_dataset(self, tokenizer):
        """Tokenize the splits with ``tokenizer`` and return ``self.dataset``."""
        self.__preprocess__(tokenizer)
        return self.dataset
